from utils.my_tools import tokenizer
from utils.my_tools import generate
from utils.my_tools import output_path__ppo_model, device, output_path__gen_model
from utils.my_tools import TestModel
from utils.time_tools import with_timer
from bdtime import tt
import torch


# Select which generator checkpoint to evaluate: either the generator that
# was fine-tuned with PPO, or the plain pretrained generator.
use_ppo = True
if use_ppo:
    # NOTE(review): torch.load unpickles a full model object — only load
    # trusted checkpoints. map_location makes loading work even when the
    # checkpoint was saved on a different device than `device`.
    model_ppo = torch.load(output_path__ppo_model, map_location=device)
    # The PPO wrapper holds the underlying generator as .model_gen.
    model_gen = model_ppo.model_gen
else:
    model_gen = torch.load(output_path__gen_model, map_location=device)


# A random batch of data (previous approach, kept for reference):
# _, input_ids, _ = tokenizer.get_batch_data(prefix=True)
#
# # Split each sequence into question and answer at the '=' token
# split = [i.index(tokenizer.encoder['=']) + 1 for i in input_ids]
# question = [input_ids[i][:split[i]] for i in range(len(input_ids))]
# answer = [input_ids[i][split[i]:] for i in range(len(input_ids))]

batch_size = 100
# Fetch a batch of (label, question tokens, attention mask, real answer tokens).
# NOTE(review): presumably _question / real_answer are per-sample token-id
# sequences — confirm against TestModel.get_question.
label, _question, attention_mask, real_answer = TestModel.get_question(prefix=True, batch_size=batch_size, ret_real_answer=True)

# Generate a prediction for each question with the generator model.
with with_timer(f"生成predict, batch_size: {batch_size}", tt) as wt:
    input_ids = []
    _predict_qa = []
    for q in _question:
        # Shape each question as a (1, seq_len) batch on the target device.
        tensor_q = torch.LongTensor(q).unsqueeze(0).to(device)
        input_ids.append(tensor_q)
        _predict_qa.append(generate(model_gen, tensor_q))

    # Keep only the generated continuation: drop the question prefix.
    _predict = []
    for full_seq, q in zip(_predict_qa, _question):
        _predict.append(full_seq[0].tolist()[len(q):])

# tt.sleep(3)


# Decode into text
# question = [tokenizer.decode(i) for i in question]
# answer = [tokenizer.decode(i) for i in real_answer]
# predict = [tokenizer.decode(i) for i in predict]

# from utils.my_tools import show_qap
# show_qap(_question, real_answer, _predict, end=10, skip_spacial_symbols=True)


# Optionally score (question + real answer) pairs with the classifier model.
flag__test_cls = 1
if flag__test_cls:
    with with_timer(f"test_cls, batch_size: {batch_size}", tt) as wt:
        from utils.my_tools import output_path__cls_model, test_predict_cls

        # Concatenate each question with its real answer into one sequence.
        # NOTE(review): torch.cat requires q and a to already be tensors —
        # confirm TestModel.get_question returns tensors here (the generation
        # step above wraps _question items in LongTensor, which accepts both
        # lists and tensors, so this file does not pin the type down).
        qa_token = [torch.cat((q, a)) for q, a in zip(_question, real_answer)]

        # Pad the concatenated sequences to a uniform length.
        qa_input_ids, qa_attention_mask = tokenizer.batch_pad(token=qa_token)

        qa_input_ids = torch.LongTensor(qa_input_ids).to(device)
        qa_attention_mask = torch.LongTensor(qa_attention_mask).to(device)

        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoints; consider map_location=device for portability.
        model_cls = torch.load(output_path__cls_model)
        test_predict_cls(model_cls, qa_input_ids, qa_attention_mask, label)


# Decode token ids back to text for comparison.
question = tokenizer.decode(_question, '', skip_spacial_symbols=True)
answer = tokenizer.decode(real_answer, '', skip_spacial_symbols=True)
predict = tokenizer.decode(_predict, '', skip_spacial_symbols=True)

# Exact-match accuracy: a prediction counts as correct only when the decoded
# predicted text equals the decoded real answer.
with with_timer(f"计算accuracy, batch_size: {batch_size}", tt) as wt:
    acc = 0
    show_times = 5  # print the first few samples for a quick eyeball check
    for i, (q, a, p) in enumerate(zip(question, answer, predict)):
        if i < show_times:
            print(a == p, '--- q, a, p ---', q, a, p)
        if a == p:
            acc += 1

    # Guard against an empty batch to avoid ZeroDivisionError.
    total = len(question)
    accuracy = round(acc / total, 3) if total else 0.0
    print('--- accuracy:', accuracy, f"total_types: {tokenizer.total_types}")


